package main

import (
	"context"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"sync"
	"time"
)

// Task is a unit of work; a worker goroutine runs it by calling it with no
// arguments.
type Task func()

// Selector chooses which worker should receive the next task from the given
// candidates. A non-nil error means no worker was chosen.
type Selector interface {
	Select([]*WorkerInfo) (*WorkerInfo, error)
}

// NormalSelector is a Selector that picks the least-loaded worker, where
// load is CurrentTasksCount / MaxTasksCount.
type NormalSelector struct{}

// Select returns the worker with the smallest load ratio that is strictly
// below 1.0. Ties go to the worker that appears first in workerInfos.
//
// Full workers (ratio >= 1.0) are never chosen. Neither is a worker whose
// MaxTasksCount is 0, since its ratio is NaN or +Inf. Each worker's counters
// are read under its CountLock read lock.
//
// An error is returned when no worker qualifies, including when workerInfos
// is empty.
func (ns *NormalSelector) Select(workerInfos []*WorkerInfo) (*WorkerInfo, error) {
	loadOf := func(w *WorkerInfo) float64 {
		w.CountLock.RLock()
		defer w.CountLock.RUnlock()
		return float64(w.CurrentTasksCount) / float64(w.MaxTasksCount)
	}

	var best *WorkerInfo
	bestLoad := 1.0
	for idx := range workerInfos {
		candidate := workerInfos[idx]
		load := loadOf(candidate)
		// Written as !(a < b) rather than a >= b so that NaN loads are
		// skipped, exactly as a failed `<` comparison would skip them.
		if !(load < bestLoad) {
			continue
		}
		best, bestLoad = candidate, load
	}

	if best != nil {
		return best, nil
	}
	return nil, errors.New("worker not found!")
}

// WorkerInfo holds the identity and current load of one worker.
type WorkerInfo struct {
	ID                int          // unique ID; also this worker's key in LoadBalancer.workerChannel
	CurrentTasksCount int          // tasks assigned to this worker that have not yet finished
	MaxTasksCount     int          // maximum number of tasks this worker may carry at once
	CountLock         sync.RWMutex // guards reads and writes of CurrentTasksCount
}

// LoadBalancer dispatches Tasks to a fixed set of worker goroutines, using a
// Selector to pick the target worker for each task.
type LoadBalancer struct {
	// workerInfos is the per-worker load state passed to selector.
	workerInfos []*WorkerInfo
	// workerChannel maps a worker ID to the channel that worker reads tasks from.
	workerChannel map[int](chan Task)
	// wg counts submitted tasks that have not finished yet.
	wg sync.WaitGroup
	// selector chooses the worker for each submitted task.
	selector Selector
	// cancel cancels the context the worker goroutines watch, letting them return.
	cancel context.CancelFunc
}

// submit hands task to the worker chosen by lb.selector and records it as
// pending on both the worker's CurrentTasksCount and lb.wg.
//
// It panics when the selector finds no worker with spare capacity (every
// worker is full). The WaitGroup is incremented only after a worker has been
// chosen. A recovered panic therefore cannot leave wait() blocked forever on
// a task that was never queued.
//
// NOTE(review): Select reads the counts and this method increments them
// under separate lock acquisitions. Concurrent callers may therefore push a
// worker past its MaxTasksCount.
func (lb *LoadBalancer) submit(task Task) {
	worker, err := lb.selector.Select(lb.workerInfos)
	if err != nil {
		// TODO: every worker is at capacity ("service full"); panicking for now.
		panic(err)
	}

	worker.CountLock.Lock()
	worker.CurrentTasksCount++
	worker.CountLock.Unlock()

	// Add must happen before the send so the worker's matching Done can
	// never run first.
	lb.wg.Add(1)
	lb.workerChannel[worker.ID] <- task
}

// wait blocks until every submitted task has finished, then cancels the
// balancer's context so the worker goroutines return.
//
// NOTE(review): it does not wait for the worker goroutines themselves to
// exit after cancellation.
func (lb *LoadBalancer) wait() {
	defer lb.cancel()
	lb.wg.Wait()
}

// LB builds a LoadBalancer with lbCount workers (IDs 0..lbCount-1) and
// starts one goroutine per worker. Tasks are routed by a NormalSelector.
//
// All workers are started up front (no slow start). Each worker may hold up
// to maxTasksPerWorker tasks, counting queued and running ones. Worker
// goroutines return once the balancer's context is cancelled (see wait).
func LB(lbCount int) *LoadBalancer {
	// maxTasksPerWorker is the capacity the selector checks against. Each
	// worker's channel gets the same buffer size, so a task the selector
	// accepted can be queued without blocking submit. A buffer of 1 would
	// block the producer while the selector still reported spare capacity.
	const maxTasksPerWorker = 10

	ctx, cancel := context.WithCancel(context.Background())

	// The selector is hard-coded for now.
	lb := &LoadBalancer{
		selector:      &NormalSelector{},
		workerChannel: make(map[int](chan Task)),
		cancel:        cancel,
	}

	for i := 0; i < lbCount; i++ {
		// The loop index doubles as the worker ID.
		w := &WorkerInfo{ID: i, MaxTasksCount: maxTasksPerWorker}
		lb.workerInfos = append(lb.workerInfos, w)
		ch := make(chan Task, maxTasksPerWorker)
		lb.workerChannel[i] = ch
		go runWorker(ctx, w, ch, &lb.wg)
	}

	return lb
}

// runWorker runs tasks received on tasks for worker w until ctx is
// cancelled. After each task it decrements w.CurrentTasksCount and marks the
// task done on wg.
func runWorker(ctx context.Context, w *WorkerInfo, tasks <-chan Task, wg *sync.WaitGroup) {
	for {
		select {
		case fn := <-tasks:
			id := w.ID
			fmt.Printf("[BEGIN %d] worker %d got task\n", id, id)
			fn()
			fmt.Printf("[DONE %d] worker %d finish task\n", id, id)
			w.CountLock.Lock()
			w.CurrentTasksCount--
			w.CountLock.Unlock()
			wg.Done()
		case <-ctx.Done():
			// Return instead of blocking forever, so the goroutine does not leak.
			return
		}
	}
}

// main starts a balancer with four workers and submits ten jobs. Each job
// logs its number and sleeps for a random 0-9 seconds. main then blocks
// until all jobs have completed.
func main() {
	balancer := LB(4)

	// makeJob binds the job number as an argument, so no per-iteration
	// variable copy is needed.
	makeJob := func(jobID int) func() {
		return func() {
			log.Printf("doing job %d\n", jobID)
			// Simulate a task that takes a while.
			time.Sleep(time.Duration(rand.Intn(10)) * time.Second)
		}
	}

	for n := 0; n < 10; n++ {
		balancer.submit(makeJob(n))
	}
	balancer.wait()
}